06. 训练迭代 / Epoch 的次数

迭代次数

训练迭代次数是一个超参数,我们可以使用一种叫做“早期停止”(或“早期终止”)的技术自动优化。

ValidationMonitor

在 tensorflow 中,我们可以使用 ValidationMonitor 与 tf.contrib.learn发挥两个功能:监督训练过程和在满足特定条件的情况下停止训练。

来自 ValidationMonitor 文档的以下示例展示了它的设置。注意最后三个参数表示我们正在优化的指标。

validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
  test_set.data,
  test_set.target,
  every_n_steps=50,
  metrics=validation_metrics,
  early_stopping_metric="loss",
  early_stopping_metric_minimize=True,
  early_stopping_rounds=200)

最后一个参数向 ValidationMonitor 表示如果损失未在 200 步(轮)训练内降低,则停止训练过程。

然后,validation_monitor 被传递给 tf.contrib.learn 的 "fit" 方法,后者运行以下训练过程:

classifier = tf.contrib.learn.DNNClassifier(
  feature_columns=feature_columns,
  hidden_units=[10, 20, 10],
  n_classes=3,
  model_dir="/tmp/iris_model",
  config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))

classifier.fit(x=training_set.data,
           y=training_set.target,
           steps=2000,
           monitors=[validation_monitor])

SessionRunHook

最近版本的 TensorFlow 废弃了 Monitor 函数,而采用 SessionRunHooks 。SessionRunHook 是 tf.train 不断发展的一部分,往后似乎将是实施早期停止的一个适当位置。

到本文写作之时,tf.train 的训练钩子函数 中已存在两个预定义的停止 Monitor 函数。

  • StopAtStepHook:用于在特定步数之后要求停止训练的 Monitor 函数
  • NanTensorHook:监控损失并在遇到 NaN 损失时停止训练的 Monitor 函数